{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7765UFHoyGx6"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "KsOkK8O69PyT"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RKQpW0JqQQmY"
      },
      "source": [
        "# 使用 Tensorflow Lattice 实现形状约束\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r61fkA2i9Y3_"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://tensorflow.google.cn/lattice/tutorials/shape_constraints\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\">在 TensorFlow.org 上查看 </a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/lattice/tutorials/shape_constraints.ipynb\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\">在 Google Colab 中运行 </a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/lattice/tutorials/shape_constraints.ipynb\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\">在 GitHub 中查看源代码</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/lattice/tutorials/shape_constraints.ipynb\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\"> 下载笔记本</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2plcL3iTVjsp"
      },
      "source": [
        "## 概述\n",
        "\n",
        "本教程概述了 TensorFlow Lattice (TFL) 库提供的约束和正则化器。我们将在合成数据集上使用 TFL Canned Estimator，但请注意，本教程中的所有内容也可以使用通过 TFL Keras 层构造的模型来完成。\n",
        "\n",
        "在继续之前，请确保您的运行时已安装所有必需的软件包（如下方代码单元中导入的软件包）。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x769lI12IZXB"
      },
      "source": [
        "## 设置"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fbBVAR6UeRN5"
      },
      "source": [
        "安装 TF Lattice 软件包："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bpXjJKpSd3j4"
      },
      "outputs": [],
      "source": [
        "#@test {\"skip\": true}\n",
        "!pip install tensorflow-lattice"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jSVl9SHTeSGX"
      },
      "source": [
        "导入所需的软件包："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "iY6awAl058TV"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "\n",
        "from IPython.core.pylabtools import figsize\n",
        "import itertools\n",
        "import logging\n",
        "import matplotlib\n",
        "from matplotlib import pyplot as plt\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import sys\n",
        "import tensorflow_lattice as tfl\n",
        "logging.disable(sys.maxsize)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7TmBk_IGgJF0"
      },
      "source": [
        "本指南中使用的默认值："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kQHPyPsPUF92"
      },
      "outputs": [],
      "source": [
        "NUM_EPOCHS = 500\n",
        "BATCH_SIZE = 64\n",
        "LEARNING_RATE=0.001"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FjR7D8Ag3z0d"
      },
      "source": [
        "## 餐厅评分的训练数据集"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a1YetzbdFOij"
      },
      "source": [
        "设想一个简化的场景，我们要确定用户是否会点击餐厅搜索结果。任务是基于给定的输入特征预测点击率 (CTR)：\n",
        "\n",
        "- 平均评分 (`avg_rating`)：一个数字特征，值在 [1,5] 区间内。\n",
        "- 评论数 (`num_reviews`)：一个数字特征，上限值为 200，我们使用该值作为衡量餐厅热度的指标。\n",
        "- 美元评分 (`dollar_rating`)：一个分类特征，其字符串值位于集合 {\"D\", \"DD\", \"DDD\", \"DDDD\"} 内。\n",
        "\n",
        "我们创建一个合成数据集，其真实 CTR 值由以下公式得出：$$ CTR = 1 / (1 + exp{\\mbox{b(dollar_rating)}-\\mbox{avg_rating}\\times log(\\mbox{num_reviews}) /4 }) $$，其中 $b(\\cdot)$ 可将每个 `dollar_rating` 转换为基准值：$$ \\mbox{D}\\to 3，\\ \\mbox{DD}\\to 2，\\ \\mbox{DDD}\\to 4，\\ \\mbox{DDDD}\\to 4.5。$$\n",
        "\n",
        "此公式反映了典型的用户模式。例如，在所有其他条件固定的情况下，用户更喜欢星级较高的餐厅，\"\\$\\$\" 餐厅的点击数高于 \"\\$\" 餐厅，低于 \"\\$\\$\\$\" 和 \"\\$\\$\\$\\$\" 餐厅。 "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mKovnyv1jATw"
      },
      "outputs": [],
      "source": [
        "def click_through_rate(avg_ratings, num_reviews, dollar_ratings):\n",
        "  dollar_rating_baseline = {\"D\": 3, \"DD\": 2, \"DDD\": 4, \"DDDD\": 4.5}\n",
        "  return 1 / (1 + np.exp(\n",
        "      np.array([dollar_rating_baseline[d] for d in dollar_ratings]) -\n",
        "      avg_ratings * np.log1p(num_reviews) / 4))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BPlgRdt6jAbP"
      },
      "source": [
        "让我们看一下此 CTR 函数的等高线图。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KC5qX_XKmc7g"
      },
      "outputs": [],
      "source": [
        "def color_bar():\n",
        "  bar = matplotlib.cm.ScalarMappable(\n",
        "      norm=matplotlib.colors.Normalize(0, 1, True),\n",
        "      cmap=\"viridis\",\n",
        "  )\n",
        "  bar.set_array([0, 1])\n",
        "  return bar\n",
        "\n",
        "\n",
        "def plot_fns(fns, split_by_dollar=False, res=25):\n",
        "  \"\"\"Generates contour plots for a list of (name, fn) functions.\"\"\"\n",
        "  num_reviews, avg_ratings = np.meshgrid(\n",
        "      np.linspace(0, 200, num=res),\n",
        "      np.linspace(1, 5, num=res),\n",
        "  )\n",
        "  if split_by_dollar:\n",
        "    dollar_rating_splits = [\"D\", \"DD\", \"DDD\", \"DDDD\"]\n",
        "  else:\n",
        "    dollar_rating_splits = [None]\n",
        "  if len(fns) == 1:\n",
        "    fig, axes = plt.subplots(2, 2, sharey=True, tight_layout=False)\n",
        "  else:\n",
        "    fig, axes = plt.subplots(\n",
        "        len(dollar_rating_splits), len(fns), sharey=True, tight_layout=False)\n",
        "  axes = axes.flatten()\n",
        "  axes_index = 0\n",
        "  for dollar_rating_split in dollar_rating_splits:\n",
        "    for title, fn in fns:\n",
        "      if dollar_rating_split is not None:\n",
        "        dollar_ratings = np.repeat(dollar_rating_split, res**2)\n",
        "        values = fn(avg_ratings.flatten(), num_reviews.flatten(),\n",
        "                    dollar_ratings)\n",
        "        title = \"{}: dollar_rating={}\".format(title, dollar_rating_split)\n",
        "      else:\n",
        "        values = fn(avg_ratings.flatten(), num_reviews.flatten())\n",
        "      subplot = axes[axes_index]\n",
        "      axes_index += 1\n",
        "      subplot.contourf(\n",
        "          avg_ratings,\n",
        "          num_reviews,\n",
        "          np.reshape(values, (res, res)),\n",
        "          vmin=0,\n",
        "          vmax=1)\n",
        "      subplot.title.set_text(title)\n",
        "      subplot.set(xlabel=\"Average Rating\")\n",
        "      subplot.set(ylabel=\"Number of Reviews\")\n",
        "      subplot.set(xlim=(1, 5))\n",
        "\n",
        "  _ = fig.colorbar(color_bar(), cax=fig.add_axes([0.95, 0.2, 0.01, 0.6]))\n",
        "\n",
        "\n",
        "figsize(11, 11)\n",
        "plot_fns([(\"CTR\", click_through_rate)], split_by_dollar=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ol91olp3muNN"
      },
      "source": [
        "### 准备数据\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H8BOshZS9xwn"
      },
      "source": [
        "现在，我们需要创建合成数据集。我们首先生成餐厅及其特征的模拟数据集。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MhqcOPdTT_wj"
      },
      "outputs": [],
      "source": [
        "def sample_restaurants(n):\n",
        "  avg_ratings = np.random.uniform(1.0, 5.0, n)\n",
        "  num_reviews = np.round(np.exp(np.random.uniform(0.0, np.log(200), n)))\n",
        "  dollar_ratings = np.random.choice([\"D\", \"DD\", \"DDD\", \"DDDD\"], n)\n",
        "  ctr_labels = click_through_rate(avg_ratings, num_reviews, dollar_ratings)\n",
        "  return avg_ratings, num_reviews, dollar_ratings, ctr_labels\n",
        "\n",
        "\n",
        "np.random.seed(42)\n",
        "avg_ratings, num_reviews, dollar_ratings, ctr_labels = sample_restaurants(2000)\n",
        "\n",
        "figsize(5, 5)\n",
        "fig, axs = plt.subplots(1, 1, sharey=False, tight_layout=False)\n",
        "for rating, marker in [(\"D\", \"o\"), (\"DD\", \"^\"), (\"DDD\", \"+\"), (\"DDDD\", \"x\")]:\n",
        "  plt.scatter(\n",
        "      x=avg_ratings[np.where(dollar_ratings == rating)],\n",
        "      y=num_reviews[np.where(dollar_ratings == rating)],\n",
        "      c=ctr_labels[np.where(dollar_ratings == rating)],\n",
        "      vmin=0,\n",
        "      vmax=1,\n",
        "      marker=marker,\n",
        "      label=rating)\n",
        "plt.xlabel(\"Average Rating\")\n",
        "plt.ylabel(\"Number of Reviews\")\n",
        "plt.legend()\n",
        "plt.xlim((1, 5))\n",
        "plt.title(\"Distribution of restaurants\")\n",
        "_ = fig.colorbar(color_bar(), cax=fig.add_axes([0.95, 0.2, 0.01, 0.6]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tRetsfLv_JSR"
      },
      "source": [
        "我们来生成训练、验证和测试数据集。当用户查看搜索结果中的餐厅时，我们可以记录用户的参与（点击或不点击）作为样本点。\n",
        "\n",
        "实际上，用户通常不会浏览所有搜索结果。这意味着用户可能只会看到被当前所用排名模型评为“好”的餐厅。这样一来，在训练数据集中，“好”餐厅将获得更频繁的关注并存在过度表示的情况。使用更多特征时，训练数据集在特征空间的“差”部分可能会存在较大空缺。\n",
        "\n",
        "使用该模型进行评分时，通常会基于分布更均衡的所有相关结果进行评估，而训练数据集并不能很好地提供这种分布。在这种情况下，灵活且复杂的模型可能会因过拟合存在过度表达情况的数据点而缺乏泛化能力。我们通过运用领域知识来添加*形状约束*，在模型无法从训练数据集获取充分信息的情况下引导模型做出合理的预测，从而解决这一问题。\n",
        "\n",
        "在本例中，训练数据集主要由用户与热门餐厅的交互组成。测试数据集采用均衡的分布来模拟上述评估环境。请注意，此类测试数据集在实际问题环境中将不可用。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jS6WOtXQ8jwX"
      },
      "outputs": [],
      "source": [
        "def sample_dataset(n, testing_set):\n",
        "  (avg_ratings, num_reviews, dollar_ratings, ctr_labels) = sample_restaurants(n)\n",
        "  if testing_set:\n",
        "    # Testing has a more uniform distribution over all restaurants.\n",
        "    num_views = np.random.poisson(lam=3, size=n)\n",
        "  else:\n",
        "    # Training/validation datasets have more views on popular restaurants.\n",
        "    num_views = np.random.poisson(lam=ctr_labels * num_reviews / 50.0, size=n)\n",
        "\n",
        "  return pd.DataFrame({\n",
        "      \"avg_rating\": np.repeat(avg_ratings, num_views),\n",
        "      \"num_reviews\": np.repeat(num_reviews, num_views),\n",
        "      \"dollar_rating\": np.repeat(dollar_ratings, num_views),\n",
        "      \"clicked\": np.random.binomial(n=1, p=np.repeat(ctr_labels, num_views))\n",
        "  })\n",
        "\n",
        "\n",
        "# Generate datasets.\n",
        "np.random.seed(42)\n",
        "data_train = sample_dataset(2000, testing_set=False)\n",
        "data_val = sample_dataset(1000, testing_set=False)\n",
        "data_test = sample_dataset(1000, testing_set=True)\n",
        "\n",
        "# Plotting dataset densities.\n",
        "figsize(12, 5)\n",
        "fig, axs = plt.subplots(1, 2, sharey=False, tight_layout=False)\n",
        "for ax, data, title in [(axs[0], data_train, \"training\"),\n",
        "                        (axs[1], data_test, \"testing\")]:\n",
        "  _, _, _, density = ax.hist2d(\n",
        "      x=data[\"avg_rating\"],\n",
        "      y=data[\"num_reviews\"],\n",
        "      bins=(np.linspace(1, 5, num=21), np.linspace(0, 200, num=21)),\n",
        "      density=True,\n",
        "      cmap=\"Blues\",\n",
        "  )\n",
        "  ax.set(xlim=(1, 5))\n",
        "  ax.set(ylim=(0, 200))\n",
        "  ax.set(xlabel=\"Average Rating\")\n",
        "  ax.set(ylabel=\"Number of Reviews\")\n",
        "  ax.title.set_text(\"Density of {} examples\".format(title))\n",
        "  _ = fig.colorbar(density, ax=ax)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4fVyLgpCT1nW"
      },
      "source": [
        "定义用于训练和评估的 input_fn："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DYzRTRR2GKoS"
      },
      "outputs": [],
      "source": [
        "train_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "    x=data_train,\n",
        "    y=data_train[\"clicked\"],\n",
        "    batch_size=BATCH_SIZE,\n",
        "    num_epochs=NUM_EPOCHS,\n",
        "    shuffle=False,\n",
        ")\n",
        "\n",
        "# feature_analysis_input_fn is used for TF Lattice estimators.\n",
        "feature_analysis_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "    x=data_train,\n",
        "    y=data_train[\"clicked\"],\n",
        "    batch_size=BATCH_SIZE,\n",
        "    num_epochs=1,\n",
        "    shuffle=False,\n",
        ")\n",
        "\n",
        "val_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "    x=data_val,\n",
        "    y=data_val[\"clicked\"],\n",
        "    batch_size=BATCH_SIZE,\n",
        "    num_epochs=1,\n",
        "    shuffle=False,\n",
        ")\n",
        "\n",
        "test_input_fn = tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "    x=data_test,\n",
        "    y=data_test[\"clicked\"],\n",
        "    batch_size=BATCH_SIZE,\n",
        "    num_epochs=1,\n",
        "    shuffle=False,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qoTrw3FZqvPK"
      },
      "source": [
        "## 拟合梯度提升树"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZklNowexE3wB"
      },
      "source": [
        "让我们从两个函数开始：`avg_rating` 和 `num_reviews`。\n",
        "\n",
        "我们创建一些辅助函数，用于绘制和计算验证与测试指标。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SX6rARJWURWl"
      },
      "outputs": [],
      "source": [
        "def analyze_two_d_estimator(estimator, name):\n",
        "  # Extract validation metrics.\n",
        "  metric = estimator.evaluate(input_fn=val_input_fn)\n",
        "  print(\"Validation AUC: {}\".format(metric[\"auc\"]))\n",
        "  metric = estimator.evaluate(input_fn=test_input_fn)\n",
        "  print(\"Testing AUC: {}\".format(metric[\"auc\"]))\n",
        "\n",
        "  def two_d_pred(avg_ratings, num_reviews):\n",
        "    results = estimator.predict(\n",
        "        tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "            x=pd.DataFrame({\n",
        "                \"avg_rating\": avg_ratings,\n",
        "                \"num_reviews\": num_reviews,\n",
        "            }),\n",
        "            shuffle=False,\n",
        "        ))\n",
        "    return [x[\"logistic\"][0] for x in results]\n",
        "\n",
        "  def two_d_click_through_rate(avg_ratings, num_reviews):\n",
        "    return np.mean([\n",
        "        click_through_rate(avg_ratings, num_reviews,\n",
        "                           np.repeat(d, len(avg_ratings)))\n",
        "        for d in [\"D\", \"DD\", \"DDD\", \"DDDD\"]\n",
        "    ],\n",
        "                   axis=0)\n",
        "\n",
        "  figsize(11, 5)\n",
        "  plot_fns([(\"{} Estimated CTR\".format(name), two_d_pred),\n",
        "            (\"CTR\", two_d_click_through_rate)],\n",
        "           split_by_dollar=False)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JVef4f8yUUbs"
      },
      "source": [
        "我们可以基于数据集拟合 TensorFlow 梯度提升决策树："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DnPYlRAo2mnQ"
      },
      "outputs": [],
      "source": [
        "feature_columns = [\n",
        "    tf.feature_column.numeric_column(\"num_reviews\"),\n",
        "    tf.feature_column.numeric_column(\"avg_rating\"),\n",
        "]\n",
        "gbt_estimator = tf.estimator.BoostedTreesClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    # Hyper-params optimized on validation set.\n",
        "    n_batches_per_layer=1,\n",
        "    max_depth=3,\n",
        "    n_trees=20,\n",
        "    min_node_weight=0.1,\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42),\n",
        ")\n",
        "gbt_estimator.train(input_fn=train_input_fn)\n",
        "analyze_two_d_estimator(gbt_estimator, \"GBT\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nYZtd6YvsNdn"
      },
      "source": [
        "即使该模型已捕获到真实 CTR 的总体形状并具有不错的验证指标，但它在输入空间的多个部分上都呈现出违背直觉的行为：估算的 CTR 随着平均评分或评论数量的增加而降低。这是由于训练数据集无法充分覆盖到的区域缺少样本点。该模型根本无法仅基于数据推断出正确的行为。\n",
        "\n",
        "为了解决这个问题，我们强制应用了形状约束，使模型必须输出相对于平均评分和评论数量单调递增的值。稍后我们将展示如何在 TFL 中实现这一方案。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Uf7WqGooFiEp"
      },
      "source": [
        "## 拟合 DNN"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_s2aT3x0E_tF"
      },
      "source": [
        "我们可以对 DNN 分类器重复相同的步骤。我们可以观察到类似的模式：样本点不够和评论数量过低就无法进行有效的推断。请注意，尽管验证指标优于梯度提升树解决方案，但测试指标却要差得多。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gFUeG6kLDNhO"
      },
      "outputs": [],
      "source": [
        "feature_columns = [\n",
        "    tf.feature_column.numeric_column(\"num_reviews\"),\n",
        "    tf.feature_column.numeric_column(\"avg_rating\"),\n",
        "]\n",
        "dnn_estimator = tf.estimator.DNNClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    # Hyper-params optimized on validation set.\n",
        "    hidden_units=[16, 8, 8],\n",
        "    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42),\n",
        ")\n",
        "dnn_estimator.train(input_fn=train_input_fn)\n",
        "analyze_two_d_estimator(dnn_estimator, \"DNN\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0Avkw-okw7JL"
      },
      "source": [
        "## 形状约束"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3ExyethCFBrP"
      },
      "source": [
        "TensorFlow Lattice (TFL) 侧重于强制应用形状约束，以保护超出训练数据的模型行为。这些形状约束应用于 TFL Keras 层。相关详细信息可参见[我们的 JMLR 论文](http://jmlr.org/papers/volume17/15-243/15-243.pdf)。\n",
        "\n",
        "在本教程中，我们使用 TF Canned Estimator 覆盖各种形状约束，但是请注意，所有这些步骤都可以使用通过 TFL Keras 层创建的模型来完成。\n",
        "\n",
        "与任何其他 TensorFlow Estimator 一样，TFL Canned Estimator 使用[特征列](https://tensorflow.google.cn/api_docs/python/tf/feature_column)定义输入格式，并使用训练 input_fn 传入数据。使用 TFL Canned Estimator 还需要：\n",
        "\n",
        "- *模型配置*：定义模型架构以及按特征的形状约束和正则化器。\n",
        "- *特征分析 input_fn*：传递数据以供 TFL 初始化的 TF input_fn。\n",
        "\n",
        "有关更详尽的介绍，请参阅 Canned Estimator 教程或 API 文档。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "anyCM4sCpOSo"
      },
      "source": [
        "### 单调性\n",
        "\n",
        "我们首先通过向两个特征添加单调性形状约束来解决单调性问题。\n",
        "\n",
        "为了命令 TFL 强制应用形状约束，我们在*特征配置*中指定约束。以下代码展示了如何通过设置 `monotonicity=\"increasing\"` 来要求输出相对于 `num_reviews` 和 `avg_rating` 单调递增。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FCm1lOjmwur_"
      },
      "outputs": [],
      "source": [
        "feature_columns = [\n",
        "    tf.feature_column.numeric_column(\"num_reviews\"),\n",
        "    tf.feature_column.numeric_column(\"avg_rating\"),\n",
        "]\n",
        "model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    feature_configs=[\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"num_reviews\",\n",
        "            lattice_size=2,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_num_keypoints=20,\n",
        "        ),\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"avg_rating\",\n",
        "            lattice_size=2,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_num_keypoints=20,\n",
        "        )\n",
        "    ])\n",
        "tfl_estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42),\n",
        ")\n",
        "tfl_estimator.train(input_fn=train_input_fn)\n",
        "analyze_two_d_estimator(tfl_estimator, \"TF Lattice\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ubNRBCWW5wQ9"
      },
      "source": [
        "使用 `CalibratedLatticeConfig` 可以创建一个封装分类器，先对每个输入（数字特征的分段线性函数）应用*校准器*，然后应用*点阵*层以非线性方式融合校准的特征。我们可以使用 `tfl.visualization` 可视化模型。特别是，下图展示了封装分类器中所含的两个已训练校准器。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "C0py9Q6OBRBE"
      },
      "outputs": [],
      "source": [
        "def save_and_visualize_lattice(tfl_estimator):\n",
        "  saved_model_path = tfl_estimator.export_saved_model(\n",
        "      \"/tmp/TensorFlow_Lattice_101/\",\n",
        "      tf.estimator.export.build_parsing_serving_input_receiver_fn(\n",
        "          feature_spec=tf.feature_column.make_parse_example_spec(\n",
        "              feature_columns)))\n",
        "  model_graph = tfl.estimators.get_model_graph(saved_model_path)\n",
        "  figsize(8, 8)\n",
        "  tfl.visualization.draw_model_graph(model_graph)\n",
        "  return model_graph\n",
        "\n",
        "_ = save_and_visualize_lattice(tfl_estimator)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7vZ5fShXs504"
      },
      "source": [
        "添加约束后，估算的 CTR 将始终随着平均评分的提高或评论数量的增加而提高。这是通过确保校准器和点阵的单调性来实现的。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RfniRZCHIvfK"
      },
      "source": [
        "### 收益递减\n",
        "\n",
        "[收益递减](https://en.wikipedia.org/wiki/Diminishing_returns)意味着增大某个特征值的边际收益将随着该值的递大而降低。在我们的案例中，我们希望 `num_reviews` 特征遵循此模式，因此我们可以相应地配置其校准器。请注意，我们可以将收益递减分解为两个充分条件：\n",
        "\n",
        "- 校准器单调递增，并且\n",
        "- 校准器为凹函数。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XQrM9BskY-wx"
      },
      "outputs": [],
      "source": [
        "feature_columns = [\n",
        "    tf.feature_column.numeric_column(\"num_reviews\"),\n",
        "    tf.feature_column.numeric_column(\"avg_rating\"),\n",
        "]\n",
        "model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    feature_configs=[\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"num_reviews\",\n",
        "            lattice_size=2,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_convexity=\"concave\",\n",
        "            pwl_calibration_num_keypoints=20,\n",
        "        ),\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"avg_rating\",\n",
        "            lattice_size=2,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_num_keypoints=20,\n",
        "        )\n",
        "    ])\n",
        "tfl_estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42),\n",
        ")\n",
        "tfl_estimator.train(input_fn=train_input_fn)\n",
        "analyze_two_d_estimator(tfl_estimator, \"TF Lattice\")\n",
        "_ = save_and_visualize_lattice(tfl_estimator)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LSmzHkPUo9u5"
      },
      "source": [
        "请注意添加凹函数约束将如何提升测试指标。预测图也可以更好地还原基本事实。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J6CP2Ovapiu3"
      },
      "source": [
        "### 二维形状约束：信任\n",
        "\n",
        "如果一家餐厅被评为 5 星级餐厅，但只有一两条评论，那么这个评分可能并不可靠（这家餐厅实际上可能不怎么样）；而有数百条评论的餐厅被评为 4 星级餐厅，那么这个评分的可靠性就高得多（在这种情况下，这家餐厅可能不错）。可以看到，餐厅的评论数量会影响我们对餐厅平均评分的信任度。\n",
        "\n",
        "我们可以使用 TFL 信任约束来告知模型，一个特征值越大（或越小）表示另一个特征的可靠或可信度越高。在特征配置中设置 `reflects_trust_in` 配置即可实现这一目标。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OA14j0erm6TJ"
      },
      "outputs": [],
      "source": [
        "feature_columns = [\n",
        "    tf.feature_column.numeric_column(\"num_reviews\"),\n",
        "    tf.feature_column.numeric_column(\"avg_rating\"),\n",
        "]\n",
        "model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    feature_configs=[\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"num_reviews\",\n",
        "            lattice_size=2,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_convexity=\"concave\",\n",
        "            pwl_calibration_num_keypoints=20,\n",
        "            # Larger num_reviews indicating more trust in avg_rating.\n",
        "            reflects_trust_in=[\n",
        "                tfl.configs.TrustConfig(\n",
        "                    feature_name=\"avg_rating\", trust_type=\"edgeworth\"),\n",
        "            ],\n",
        "        ),\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"avg_rating\",\n",
        "            lattice_size=2,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_num_keypoints=20,\n",
        "        )\n",
        "    ])\n",
        "tfl_estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42),\n",
        ")\n",
        "tfl_estimator.train(input_fn=train_input_fn)\n",
        "analyze_two_d_estimator(tfl_estimator, \"TF Lattice\")\n",
        "model_graph = save_and_visualize_lattice(tfl_estimator)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "puvP9X8XxyRV"
      },
      "source": [
        "下图展示了已训练的点阵函数。基于信任约束的强制性限制，我们期望校准的 `num_reviews` 值越大，相对于校准的 `avg_rating` 的斜率就越高，进而使点阵输出的跨度更大。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "RounEQebxxnA"
      },
      "outputs": [],
      "source": [
        "lat_mesh_n = 12\n",
        "lat_mesh_x, lat_mesh_y = tfl.test_utils.two_dim_mesh_grid(\n",
        "    lat_mesh_n**2, 0, 0, 1, 1)\n",
        "lat_mesh_fn = tfl.test_utils.get_hypercube_interpolation_fn(\n",
        "    model_graph.output_node.weights.flatten())\n",
        "lat_mesh_z = [\n",
        "    lat_mesh_fn([lat_mesh_x.flatten()[i],\n",
        "                 lat_mesh_y.flatten()[i]]) for i in range(lat_mesh_n**2)\n",
        "]\n",
        "trust_plt = tfl.visualization.plot_outputs(\n",
        "    (lat_mesh_x, lat_mesh_y),\n",
        "    {\"Lattice Lookup\": lat_mesh_z},\n",
        "    figsize=(6, 6),\n",
        ")\n",
        "trust_plt.title(\"Trust\")\n",
        "trust_plt.xlabel(\"Calibrated avg_rating\")\n",
        "trust_plt.ylabel(\"Calibrated num_reviews\")\n",
        "trust_plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SKe3UHX6pUjw"
      },
      "source": [
        "### 平滑校准器\n",
        "\n",
        "现在我们来看一下 `avg_rating` 的校准器。尽管单调递增，但其斜率的变化却唐突且难以解释。这表明我们可能要考虑使用 `regularizer_configs` 中的正则化器设置平滑此校准器。\n",
        "\n",
        "在此，我们应用了 `wrinkle` 正则化器来减少曲率的变化。您也可以使用 `laplacian` 正则化器展平校准器，以及使用 `hessian` 正则化器使其更加线性。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qxFHH3hSpWfq"
      },
      "outputs": [],
      "source": [
        "feature_columns = [\n",
        "    tf.feature_column.numeric_column(\"num_reviews\"),\n",
        "    tf.feature_column.numeric_column(\"avg_rating\"),\n",
        "]\n",
        "model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    feature_configs=[\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"num_reviews\",\n",
        "            lattice_size=2,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_convexity=\"concave\",\n",
        "            pwl_calibration_num_keypoints=20,\n",
        "            regularizer_configs=[\n",
        "                tfl.configs.RegularizerConfig(name=\"calib_wrinkle\", l2=1.0),\n",
        "            ],\n",
        "            reflects_trust_in=[\n",
        "                tfl.configs.TrustConfig(\n",
        "                    feature_name=\"avg_rating\", trust_type=\"edgeworth\"),\n",
        "            ],\n",
        "        ),\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"avg_rating\",\n",
        "            lattice_size=2,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_num_keypoints=20,\n",
        "            regularizer_configs=[\n",
        "                tfl.configs.RegularizerConfig(name=\"calib_wrinkle\", l2=1.0),\n",
        "            ],\n",
        "        )\n",
        "    ])\n",
        "tfl_estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42),\n",
        ")\n",
        "tfl_estimator.train(input_fn=train_input_fn)\n",
        "analyze_two_d_estimator(tfl_estimator, \"TF Lattice\")\n",
        "_ = save_and_visualize_lattice(tfl_estimator)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HHpp4goLvuPi"
      },
      "source": [
        "校准器现已变得平滑，总体估算的 CTR 可更准确地符合实际情况。这在测试指标和等高线图中均有体现。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pSUd6aFlpYz4"
      },
      "source": [
        "### 分类校准的部分单调性\n",
        "\n",
        "到目前为止，我们一直是仅在模型中使用两个数字特征。在此，我们将使用分类校准层添加第三个特征。我们还是从设置用于绘图和指标计算的辅助函数开始。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5tLDKwTmjrLw"
      },
      "outputs": [],
      "source": [
        "def analyze_three_d_estimator(estimator, name):\n",
        "  # Extract validation metrics.\n",
        "  metric = estimator.evaluate(input_fn=val_input_fn)\n",
        "  print(\"Validation AUC: {}\".format(metric[\"auc\"]))\n",
        "  metric = estimator.evaluate(input_fn=test_input_fn)\n",
        "  print(\"Testing AUC: {}\".format(metric[\"auc\"]))\n",
        "\n",
        "  def three_d_pred(avg_ratings, num_reviews, dollar_rating):\n",
        "    results = estimator.predict(\n",
        "        tf.compat.v1.estimator.inputs.pandas_input_fn(\n",
        "            x=pd.DataFrame({\n",
        "                \"avg_rating\": avg_ratings,\n",
        "                \"num_reviews\": num_reviews,\n",
        "                \"dollar_rating\": dollar_rating,\n",
        "            }),\n",
        "            shuffle=False,\n",
        "        ))\n",
        "    return [x[\"logistic\"][0] for x in results]\n",
        "\n",
        "  figsize(11, 22)\n",
        "  plot_fns([(\"{} Estimated CTR\".format(name), three_d_pred),\n",
        "            (\"CTR\", click_through_rate)],\n",
        "           split_by_dollar=True) "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CnPiqf4rq6kJ"
      },
      "source": [
        "要包含第三个特征 `dollar_rating`，请回想一下，分类特征在 TFL 中需要稍作不同处理，无论是作为特征列还是作为特征配置。在此，我们强制应用部分单调性约束，即在所有其他输入均固定不变的情况下，“DD”餐厅的输出应大于“D”餐厅的输出。这是使用特征配置中的 `monotonicity` 设置完成的。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m-w7iGEEpgGt"
      },
      "outputs": [],
      "source": [
        "feature_columns = [\n",
        "    tf.feature_column.numeric_column(\"num_reviews\"),\n",
        "    tf.feature_column.numeric_column(\"avg_rating\"),\n",
        "    tf.feature_column.categorical_column_with_vocabulary_list(\n",
        "        \"dollar_rating\",\n",
        "        vocabulary_list=[\"D\", \"DD\", \"DDD\", \"DDDD\"],\n",
        "        dtype=tf.string,\n",
        "        default_value=0),\n",
        "]\n",
        "model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    feature_configs=[\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"num_reviews\",\n",
        "            lattice_size=2,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_convexity=\"concave\",\n",
        "            pwl_calibration_num_keypoints=20,\n",
        "            regularizer_configs=[\n",
        "                tfl.configs.RegularizerConfig(name=\"calib_wrinkle\", l2=1.0),\n",
        "            ],\n",
        "            reflects_trust_in=[\n",
        "                tfl.configs.TrustConfig(\n",
        "                    feature_name=\"avg_rating\", trust_type=\"edgeworth\"),\n",
        "            ],\n",
        "        ),\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"avg_rating\",\n",
        "            lattice_size=2,\n",
        "            monotonicity=\"increasing\",\n",
        "            pwl_calibration_num_keypoints=20,\n",
        "            regularizer_configs=[\n",
        "                tfl.configs.RegularizerConfig(name=\"calib_wrinkle\", l2=1.0),\n",
        "            ],\n",
        "        ),\n",
        "        tfl.configs.FeatureConfig(\n",
        "            name=\"dollar_rating\",\n",
        "            lattice_size=2,\n",
        "            pwl_calibration_num_keypoints=4,\n",
        "            # Here we only specify one monotonicity:\n",
        "            # `D` resturants has smaller value than `DD` restaurants\n",
        "            monotonicity=[(\"D\", \"DD\")],\n",
        "        ),\n",
        "    ])\n",
        "tfl_estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42),\n",
        ")\n",
        "tfl_estimator.train(input_fn=train_input_fn)\n",
        "analyze_three_d_estimator(tfl_estimator, \"TF Lattice\")\n",
        "_ = save_and_visualize_lattice(tfl_estimator)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gdIzhYL79_Pp"
      },
      "source": [
        "此分类校准器展示了模型输出的优先级：DD > D > DDD > DDDD，这与我们的设置一致。请注意，还有一列用于缺失值。尽管我们的训练和测试数据中没有缺失特征，但当下游模型使用期间出现值缺失的情况时，该模型可提供对缺失值的推算。\n",
        "\n",
        "在此，我们还将绘制以 `dollar_rating` 为条件的此模型的预测 CTR。请注意，我们需要的所有约束在每个切片内都要得到满足。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rh0H2b6l_rwZ"
      },
      "source": [
        "### 输出校准"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KPb2ri4e7HXF"
      },
      "source": [
        "对于到目前为止我们训练的所有 TFL 模型，点阵层（在模型图中示为“Lattice”）均直接输出模型预测。有时，我们不确定是否应对点阵输出进行重新调整以激发模型输出：\n",
        "\n",
        "- 特征为 $log$ 计数，而标签为计数。\n",
        "- 点阵被配置为仅包含少量顶点，但标签分布却相对复杂。\n",
        "\n",
        "在这些情况下，我们可以在点阵输出和模型输出之间添加另一个校准器，以提高模型的灵活性。在此，我们向刚刚构建的模型添加一个带有 5 个关键点的校准器层。我们还会为输出校准器添加一个正则化器，使函数保持平滑。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k5Sg_gUj_0i4"
      },
      "outputs": [],
      "source": [
        "feature_columns = [\n",
        "    tf.feature_column.numeric_column(\"num_reviews\"),\n",
        "    tf.feature_column.numeric_column(\"avg_rating\"),\n",
        "    tf.feature_column.categorical_column_with_vocabulary_list(\n",
        "        \"dollar_rating\",\n",
        "        vocabulary_list=[\"D\", \"DD\", \"DDD\", \"DDDD\"],\n",
        "        dtype=tf.string,\n",
        "        default_value=0),\n",
        "]\n",
        "model_config = tfl.configs.CalibratedLatticeConfig(\n",
        "    output_calibration=True,\n",
        "    output_calibration_num_keypoints=5,\n",
        "    regularizer_configs=[\n",
        "        tfl.configs.RegularizerConfig(name=\"output_calib_wrinkle\", l2=0.1),\n",
        "    ],\n",
        "    feature_configs=[\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name=\"num_reviews\",\n",
        "        lattice_size=2,\n",
        "        monotonicity=\"increasing\",\n",
        "        pwl_calibration_convexity=\"concave\",\n",
        "        pwl_calibration_num_keypoints=20,\n",
        "        regularizer_configs=[\n",
        "            tfl.configs.RegularizerConfig(name=\"calib_wrinkle\", l2=1.0),\n",
        "        ],\n",
        "        reflects_trust_in=[\n",
        "            tfl.configs.TrustConfig(\n",
        "                feature_name=\"avg_rating\", trust_type=\"edgeworth\"),\n",
        "        ],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name=\"avg_rating\",\n",
        "        lattice_size=2,\n",
        "        monotonicity=\"increasing\",\n",
        "        pwl_calibration_num_keypoints=20,\n",
        "        regularizer_configs=[\n",
        "            tfl.configs.RegularizerConfig(name=\"calib_wrinkle\", l2=1.0),\n",
        "        ],\n",
        "    ),\n",
        "    tfl.configs.FeatureConfig(\n",
        "        name=\"dollar_rating\",\n",
        "        lattice_size=2,\n",
        "        pwl_calibration_num_keypoints=4,\n",
        "        # Here we only specify one monotonicity:\n",
        "        # `D` resturants has smaller value than `DD` restaurants\n",
        "        monotonicity=[(\"D\", \"DD\")],\n",
        "    ),\n",
        "])\n",
        "tfl_estimator = tfl.estimators.CannedClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    model_config=model_config,\n",
        "    feature_analysis_input_fn=feature_analysis_input_fn,\n",
        "    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),\n",
        "    config=tf.estimator.RunConfig(tf_random_seed=42),\n",
        ")\n",
        "tfl_estimator.train(input_fn=train_input_fn)\n",
        "analyze_three_d_estimator(tfl_estimator, \"TF Lattice\")\n",
        "_ = save_and_visualize_lattice(tfl_estimator)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TLOGDrYY0hH7"
      },
      "source": [
        "最终的测试指标和统计图展示了使用常识约束如何帮助模型避免意外行为并针对整个输入空间实现更好的推断。"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "shape_constraints.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
